package cn.you.GenghisKhan.rpc.generic.dubbo.filter;
import cn.you.GenghisKhan.common.spring.SpringContextHolder;
import cn.you.GenghisKhan.common.utils.ReflectUtil;
import com.alibaba.dubbo.common.URL;
import com.alibaba.dubbo.common.extension.Activate;
import com.alibaba.dubbo.common.extension.ExtensionLoader;
import com.alibaba.dubbo.common.io.UnsafeByteArrayInputStream;
import com.alibaba.dubbo.common.io.UnsafeByteArrayOutputStream;
import com.alibaba.dubbo.common.serialize.Serialization;
import com.alibaba.dubbo.common.utils.PojoUtils;
import com.alibaba.dubbo.common.utils.ReflectUtils;
import com.alibaba.dubbo.common.utils.StringUtils;
import com.alibaba.dubbo.rpc.*;
import com.alibaba.dubbo.rpc.service.GenericException;
import com.alibaba.dubbo.rpc.support.ProtocolUtils;
import com.alibaba.fastjson.JSON;

import java.io.IOException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Core provider-side filter for generic ({@code $invoke}) invocations.
 *
 * <p>Author: wl
 *
 * <p>For a {@code $invoke(methodName, parameterTypes, args)} call against a provider whose URL is
 * not itself marked {@code generic}, this filter resolves the target method, converts the
 * arguments according to the {@code generic} attachment, invokes the real method and converts the
 * result back:
 * <ul>
 *   <li>{@code nativejava}: every argument must be a Java-serialized {@code byte[]}; the result is
 *       returned Java-serialized as well.</li>
 *   <li>any other non-default, non-empty value: arguments are passed through unchanged.</li>
 *   <li>empty or the default generic serialization: arguments go through
 *       {@link PojoUtils#realize} and the result through {@link PojoUtils#generalize}. When the
 *       attachment is exactly {@code "true"} and the call uses the named-argument convention (a
 *       {@link Map} of parameter name to value as first argument), the map is first re-bound to
 *       the method's positional parameters.</li>
 * </ul>
 * Any exception raised by the target is returned as a {@link GenericException}.
 */
@Activate(group = {"provider"})
public class RewriteGenericFilter implements Filter {

    /** Name of the synthetic generic method. */
    private static final String INVOKE_METHOD = "$invoke";

    /** Attachment / URL parameter key carrying the generic mode. */
    private static final String GENERIC_KEY = "generic";

    /** Extension name of Dubbo's native Java serialization. */
    private static final String NATIVE_JAVA = "nativejava";

    @Override
    public Result invoke(Invoker<?> invoker, Invocation inv) throws RpcException {
        if (!isGenericInvocation(invoker, inv)) {
            return invoker.invoke(inv);
        }
        Object[] invokeArgs = inv.getArguments();
        String name = ((String) invokeArgs[0]).trim();
        String[] types = (String[]) invokeArgs[1];
        Object[] args = (Object[]) invokeArgs[2];

        try {
            Method method = ReflectUtils.findMethodByMethodSignature(invoker.getInterface(), name, types);
            Class<?>[] params = method.getParameterTypes();
            if (args == null) {
                args = new Object[params.length];
            }

            String generic = inv.getAttachment(GENERIC_KEY);
            if (!StringUtils.isEmpty(generic) && !ProtocolUtils.isDefaultGenericSerialization(generic)) {
                // Non-default serialization: only nativejava needs decoding, any other value
                // passes the arguments through untouched.
                if (ProtocolUtils.isJavaGenericSerialization(generic)) {
                    deserializeNativeJavaArgs(args);
                }
            } else {
                if (isNamedArgumentCall(generic, args, params.length)) {
                    args = bindNamedArguments(invoker, name, params, (Map<?, ?>) args[0]);
                }
                args = PojoUtils.realize(args, params, method.getGenericParameterTypes());
            }

            Result result = invoker.invoke(new RpcInvocation(method, args, inv.getAttachments()));
            if (result.hasException()) {
                // Surface every failure as a GenericException. One that already is a
                // GenericException is kept as-is rather than dropped.
                Throwable exception = result.getException();
                Throwable wrapped = exception instanceof GenericException
                        ? exception
                        : new GenericException(exception);
                return new RpcResult(wrapped);
            }
            if (ProtocolUtils.isJavaGenericSerialization(generic)) {
                return new RpcResult(serializeNativeJava(result.getValue()));
            }
            return new RpcResult(PojoUtils.generalize(result.getValue()));
        } catch (NoSuchMethodException | ClassNotFoundException e) {
            throw new RpcException(e.getMessage(), e);
        }
    }

    /**
     * Returns whether {@code inv} is a three-argument {@code $invoke} call that this filter must
     * translate, i.e. the provider URL is not itself flagged {@code generic}.
     */
    private static boolean isGenericInvocation(Invoker<?> invoker, Invocation inv) {
        Object[] arguments = inv.getArguments();
        return INVOKE_METHOD.equals(inv.getMethodName())
                && arguments != null
                && arguments.length == 3
                && !invoker.getUrl().getParameter(GENERIC_KEY, false);
    }

    /** Looks up Dubbo's native Java serialization extension. */
    private static Serialization nativeJavaSerialization() {
        return ExtensionLoader.getExtensionLoader(Serialization.class).getExtension(NATIVE_JAVA);
    }

    /**
     * Replaces, in place, every Java-serialized {@code byte[]} argument by its deserialized value.
     *
     * <p>SECURITY NOTE(review): this is Java-native deserialization of caller-supplied bytes and
     * is only safe when callers are trusted.
     *
     * @throws RpcException if an argument is not a {@code byte[]} (including {@code null}) or
     *     cannot be deserialized
     */
    private static void deserializeNativeJavaArgs(Object[] args) {
        for (int i = 0; i < args.length; i++) {
            if (!(args[i] instanceof byte[])) {
                Object actualType = args[i] == null ? null : args[i].getClass();
                throw new RpcException(new StringBuilder(32)
                        .append("Generic serialization [")
                        .append(NATIVE_JAVA)
                        .append("] only support message type ")
                        .append(byte[].class)
                        .append(" and your message type is ")
                        .append(actualType)
                        .toString());
            }
            try {
                UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream((byte[]) args[i]);
                args[i] = nativeJavaSerialization().deserialize((URL) null, is).readObject();
            } catch (Exception e) {
                throw new RpcException("Deserialize argument [" + (i + 1) + "] failed.", e);
            }
        }
    }

    /**
     * Java-serializes a result value for a {@code nativejava} caller.
     *
     * @throws RpcException wrapping the {@link IOException} if serialization fails
     */
    private static byte[] serializeNativeJava(Object value) {
        try {
            UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream(512);
            nativeJavaSerialization().serialize((URL) null, os).writeObject(value);
            return os.toByteArray();
        } catch (IOException e) {
            throw new RpcException("Serialize result failed.", e);
        }
    }

    /**
     * Decides whether a default-serialization call uses the named-argument convention.
     *
     * <p>The {@code generic} attachment must be exactly {@code "true"}, and the first argument
     * must be a {@link Map} of parameter name to value. A {@code null} first argument also counts
     * when the argument count cannot match a positional call; it is then read as an empty map.
     * Everything else is left to the standard positional handling.
     *
     * <p>NOTE(review): a positional call whose first argument is a POJO generalized to a
     * {@code Map} is indistinguishable from a named call and is treated as one.
     *
     * @param generic the {@code generic} attachment, possibly {@code null}
     * @param args the call's argument array, never {@code null}
     * @param arity number of parameters of the resolved method
     */
    private static boolean isNamedArgumentCall(String generic, Object[] args, int arity) {
        if (!"true".equals(generic) || args.length == 0) {
            return false;
        }
        Object first = args[0];
        return first instanceof Map || (first == null && args.length != arity);
    }

    /**
     * Re-binds a named-argument call into positional arguments matching {@code params}.
     *
     * <p>Parameter names are read from the Spring bean registered for the invoker's interface.
     * A parameter whose name is absent from {@code namedArgs} is bound to {@code null}. When
     * {@code namedArgs} is {@code null}, every parameter is bound to {@code null}.
     *
     * @throws RpcException if no bean is found or its parameter names cannot be resolved
     */
    private static Object[] bindNamedArguments(Invoker<?> invoker, String methodName,
                                               Class<?>[] params, Map<?, ?> namedArgs) {
        Object[] bound = new Object[params.length];
        if (params.length == 0) {
            return bound;
        }
        Object bean = SpringContextHolder.getBean(invoker.getInterface());
        if (bean == null) {
            throw new RpcException("No Spring bean found for " + invoker.getInterface().getName());
        }
        // NOTE(review): presumably returns the declared parameter names of methodName in order.
        // Confirm how overloads are resolved, and that a proxied bean class still exposes them.
        List<String> paramNames = ReflectUtil.getParamterName(bean.getClass(), methodName);
        if (paramNames == null || paramNames.size() < params.length) {
            throw new RpcException("Cannot resolve parameter names of " + methodName + " on "
                    + bean.getClass().getName() + ": got " + paramNames);
        }
        for (int i = 0; i < params.length; i++) {
            Object value = namedArgs == null ? null : namedArgs.get(paramNames.get(i));
            bound[i] = toParameterValue(value, params[i]);
        }
        return bound;
    }

    /**
     * Converts one named value into the declared parameter type.
     *
     * <p>{@code null} stays {@code null}. Values for "base" types (see {@link #isBaseDataType})
     * are passed through untouched; they still go through the {@code PojoUtils.realize} pass
     * that follows in {@code invoke}. Any other value is bound with fastjson.
     */
    private static Object toParameterValue(Object value, Class<?> type) {
        if (value == null) {
            return null;
        }
        if (isBaseType(type)) {
            return value;
        }
        // A String is taken to already be JSON text. Anything else (e.g. a nested Map, List or
        // array) is re-encoded first: Map.toString() yields "{k=v}", which is not valid JSON.
        String json = value instanceof String ? (String) value : JSON.toJSONString(value);
        return JSON.parseObject(json, type);
    }

    /**
     * Returns whether {@code clazz} is a "base" type whose values this filter passes through
     * without JSON binding. These are:
     * <ul>
     *   <li>any primitive (per {@link Class#isPrimitive()});</li>
     *   <li>the eight primitive wrapper classes;</li>
     *   <li>{@link String}, {@link BigDecimal}, {@link BigInteger} and {@link Date}.</li>
     * </ul>
     *
     * @param clazz the class to test; must not be {@code null}
     * @return {@code true} if {@code clazz} is one of the types listed above
     * @throws Exception never in practice; the clause is kept so existing callers that catch or
     *     declare {@code Exception} keep compiling
     */
    public static boolean isBaseDataType(Class<?> clazz) throws Exception {
        return isBaseType(clazz);
    }

    /** Non-throwing implementation of {@link #isBaseDataType(Class)}. */
    private static boolean isBaseType(Class<?> clazz) {
        return clazz.isPrimitive()
                || clazz == String.class
                || clazz == Integer.class
                || clazz == Byte.class
                || clazz == Long.class
                || clazz == Double.class
                || clazz == Float.class
                || clazz == Character.class
                || clazz == Short.class
                || clazz == BigDecimal.class
                || clazz == BigInteger.class
                || clazz == Boolean.class
                || clazz == Date.class;
    }
}